# -*- coding: utf-8 -*-
# CS246 - Colab 2
## Frequent Pattern Mining in Spark

### Setup

import pyspark
from pyspark.sql import *
from pyspark.sql.functions import *
from pyspark import SparkContext, SparkConf

"""Let's initialize the Spark context."""

# Build the Spark configuration: pin the Spark web UI to port 4050.
conf = SparkConf().set("spark.ui.port", "4050")

# Create the SparkContext, or reuse the one that already exists when this
# cell is re-run. A plain SparkContext(conf=...) raises "Cannot run multiple
# SparkContexts at once" on a second run. Note: if a context already exists,
# getOrCreate returns it as-is and `conf` is not re-applied.
sc = SparkContext.getOrCreate(conf=conf)

# The SparkSession picks up the active SparkContext created above.
spark = SparkSession.builder.getOrCreate()

"""You can check the current Spark version with `spark.version` and get the link to the web interface with `sc.uiWebUrl`. In the Spark UI, you can monitor the progress of your jobs and debug performance bottlenecks (if your Colab is running with a **local runtime**)."""


"""If you are running this Colab on the Google-hosted runtime, you can create an *ngrok* tunnel so that you can still reach the Spark UI (the tunnel-setup cell from the original Colab is not included in this script)."""

"""### Your task

If you ran the setup stage successfully, you are ready to work with the **3 Million Instacart Orders** dataset. If you want to read more about it, check the [official Instacart blog post](https://tech.instacart.com/3-million-instacart-orders-open-sourced-d40d29ead6f2) about it, a concise [schema description](https://gist.github.com/jeremystan/c3b39d947d9b88b3ccff3147dbcf6c6b) of the dataset, and the [download page](https://www.instacart.com/datasets/grocery-shopping-2017).

In this Colab, we will be working only with a small training dataset (~131K orders) to perform fast Frequent Pattern Mining with the FP-Growth algorithm.
"""

# Both CSV files carry a header row; let Spark infer the column types.
csv_options = {'header': True, 'inferSchema': True}
products = spark.read.csv('../data/products.csv', **csv_options)
orders = spark.read.csv('../data/order_products__train.csv', **csv_options)

# Print the inferred schemas, products first and then orders.
for frame in (products, orders):
    frame.printSchema()

"""Use the Spark DataFrame API to join 'products' and 'orders', so that you can see the product names in each transaction (and not only their ids). Then, group the orders by 'order_id' to obtain one row per basket (i.e., the set of products purchased together by one customer)."""

# YOUR CODE HERE
# Join on the column *name* so the result keeps a single `product_id` column.
# Joining on `orders.product_id == products.product_id` keeps both copies, and
# any later reference to `product_id` becomes ambiguous.
joined = orders.join(products, on='product_id', how='inner')
joined.printSchema()

# Build one row per basket from the *joined* data, so every basket also
# carries its product names (assumes the Instacart `product_name` column;
# confirm against the schema printed above).
# FP-Growth rejects transactions that contain a repeated item, so
# collect_set (not collect_list) guarantees unique items per basket.
# `product_id_list` is the items column consumed by FPGrowth below, and
# `product_names` is the human-readable view the task asks for.
orders_agged = (
    joined
    .groupBy('order_id')
    .agg(
        collect_set('product_id').alias('product_id_list'),
        collect_set('product_name').alias('product_names'),
    )
)
orders_agged.show(10, truncate=False)

"""In this Colab we will explore [MLlib](https://spark.apache.org/mllib/), Apache Spark's scalable machine learning library. Specifically, you can use its implementation of the [FP-Growth](https://spark.apache.org/docs/latest/ml-frequent-pattern-mining.html#fp-growth) algorithm to perform Frequent Pattern Mining efficiently in Spark.
Use the Python example in the documentation, and train a model with 

```minSupport=0.01``` and ```minConfidence=0.5```
"""

# YOUR CODE HERE
from pyspark.ml.fpm import FPGrowth

# Configure FP-Growth with its fluent setters:
# - mine the `product_id_list` column;
# - keep itemsets whose support is at least 0.01;
# - keep rules whose confidence is at least 0.5.
fpGrowth = (
    FPGrowth()
    .setItemsCol("product_id_list")
    .setMinSupport(0.01)
    .setMinConfidence(0.5)
)
model = fpGrowth.fit(orders_agged)

"""Compute how many frequent itemsets and association rules were generated by running FP-growth."""

# YOUR CODE HERE
# Display frequent itemsets (count observed on this training set with
# minSupport=0.01: 120). truncate=False keeps the item arrays readable.
print('frequent itemsets:')
print(model.freqItemsets.count())
model.freqItemsets.sort(desc('freq')).show(10, truncate=False)

# Display generated association rules (observed count with
# minSupport=0.01, minConfidence=0.5: 0).
print('association rules:')
print(model.associationRules.count())
model.associationRules.sort(desc('confidence')).show(10, truncate=False)

# transform examines the input items against all the association rules and
# summarizes the consequents as a prediction.
print('predictions:')
model.transform(orders_agged).show(10)

"""Now retrain the FP-growth model changing only 
```minSupport=0.001```
and compute how many frequent itemsets and association rules were generated.
"""

# YOUR CODE HERE
# Retrain with only minSupport lowered to 0.001; minConfidence stays at 0.5.
fpGrowth = FPGrowth(itemsCol="product_id_list", minSupport=0.001, minConfidence=0.5)
model = fpGrowth.fit(orders_agged)

# Display frequent itemsets (count observed on this training set with
# minSupport=0.001: 4444). truncate=False keeps the item arrays readable.
print('frequent itemsets:')
print(model.freqItemsets.count())
model.freqItemsets.sort(desc('freq')).show(10, truncate=False)

# Display generated association rules (observed count with
# minSupport=0.001, minConfidence=0.5: 11).
print('association rules:')
print(model.associationRules.count())
model.associationRules.sort(desc('confidence')).show(10, truncate=False)

# transform examines the input items against all the association rules and
# summarizes the consequents as a prediction.
print('predictions:')
model.transform(orders_agged).show(10)
